{
 "cells": [
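  {
   "cell_type": "markdown",
   "id": "3b28eab8-69fa-4de3-a956-fdfdfe60f57a",
   "metadata": {},
   "source": [
    "# SplitRec: Using the Split MMoE Algorithm in SecretFlow (TensorFlow Backend)\n",
    "Multi-task learning aims to improve the learning efficiency and quality of each task by exploiting the commonalities and differences between tasks.\n",
    "\n",
    "MMoE is a classic multi-task model proposed by Google at KDD 2018 in [Modeling Task Relationships in Multi-task Learning with Multi-gate Mixture-of-Experts](https://dl.acm.org/doi/pdf/10.1145/3219819.3220007). This tutorial shows how to use the split MMoE algorithm in SecretFlow."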
   ]
  },
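  {
   "cell_type": "markdown",
   "id": "5ff15be2-ea78-4b1e-b56f-07e080c7dc68",
   "metadata": {},
   "source": [
    "## The MMoE Model\n",
    "The difficulty of multi-task learning is that when tasks are only weakly related, they tend to interfere with each other during training, which hurts model quality. A single model is often good at learning some of the objectives but performs poorly on the others. MMoE is a compromise between a single global model and multiple local models, and it addresses this problem well.\n",
    "\n",
    "MMoE consists of several expert networks (Experts) and gating networks (Gates). The gating networks control how much each expert contributes to each task, which also lets different experts learn different aspects of the data, while sharing the experts across tasks keeps the number of model parameters down. The model structure is shown below."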
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ce703399-5ece-479c-b0a9-dea94a6fc7a4",
   "metadata": {},
   "source": [
    "![mmoe](./resources/mmoe0.png)"
   ]
  },
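  {
   "cell_type": "markdown",
   "id": "6d3be61a-1364-40cd-baf9-4bed1120c549",
   "metadata": {},
   "source": [
    "Consider an MMoE model with K tasks and n experts. The output of the k-th task is computed as:\n",
    "$$y_k=h_k\\left(\\sum_{i=1}^{n}g_i^k(x)f_i(x)\\right)$$\n",
    "\n",
    "For task k, each expert computes its output $f_i(x)$, and $g_i^k(x),\\ i=1,2,\\dots,n$ is the probability that expert i is selected. The expert outputs are combined as a weighted sum and passed to Tower k (the function $h_k$) for learning."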
   ]
  },
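  {
   "cell_type": "markdown",
   "id": "mmoe-gate-softmax-sketch",
   "metadata": {},
   "source": [
    "The formula above leaves the form of the gate $g^k$ open. As a reference point, the original MMoE paper defines each gate as a softmax over a linear transformation of the input. Whether the SecretFlow wrapper uses exactly this form is an assumption here, and the symbols $W_{gk}$ and $d$ (the input dimension) come from the paper rather than from this tutorial:\n",
    "$$g^k(x)=\\mathrm{softmax}(W_{gk}x),\\quad W_{gk}\\in\\mathbb{R}^{n\\times d}$$\n",
    "\n",
    "Under that definition the gate outputs are non-negative and satisfy $\\sum_{i=1}^{n}g_i^k(x)=1$, so the argument of $h_k$ is a convex combination of the expert outputs $f_i(x)$."
   ]
  },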
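  {
   "cell_type": "markdown",
   "id": "459d7016-c98d-4257-89dd-c16f9c736ec6",
   "metadata": {},
   "source": [
    "## MMoE in SecretFlow\n",
    "The MMoE implementation in SecretFlow targets the case where all task labels are held by the same party. Each party's features first pass through its own base model, and the base model outputs are then fed into the fuse model for multi-task learning. The model structure is shown below.\n",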
    "\n",
    "![mmoe1](./resources/mmoe1.png)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e91bb7f9-5f31-4517-8b46-1d3b124a6c21",
   "metadata": {},
   "source": [
    "## SecretFlow Wrappers\n",
    "SecretFlow provides wrappers for a variety of applications. The MMoE wrapper lives in secretflow/ml/nn/applications/sl_mmoe_tf.py and provides two classes, `MMoEBase` and `MMoEFuse`. The example below shows how to train with SecretFlow's split MMoE wrapper."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ca016533-0a7b-45a7-a49f-a60d8f133fe7",
   "metadata": {},
   "source": [
    "## Environment Setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "9a49f3af-98fe-4870-9b36-c871955c605a",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The version of SecretFlow: 1.1.0.dev20230926\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2023-09-26 19:48:38,566\tINFO worker.py:1538 -- Started a local Ray instance.\n"
     ]
    }
   ],
   "source": [
    "import secretflow as sf\n",
    "\n",
    "# Check the version of your SecretFlow\n",
    "print('The version of SecretFlow: {}'.format(sf.__version__))\n",
    "\n",
    "# In case you have a running secretflow runtime already.\n",
    "sf.shutdown()\n",
    "sf.init(['alice', 'bob', 'charlie'], address=\"local\", log_to_driver=False)\n",
    "alice, bob, charlie = sf.PYU('alice'), sf.PYU('bob'), sf.PYU('charlie')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "48e2aece-51da-494c-9869-4657bbf80956",
   "metadata": {},
   "source": [
    "## Dataset\n",
    "We use the Census Income dataset from UCI. It contains 15 census attributes; here we use 13 of them as features to predict income level and marital status.\n",
    "\n",
    "[Dataset homepage](https://archive.ics.uci.edu/dataset/20/census+income)\n",
    "\n",
    "[Dataset download](https://archive.ics.uci.edu/static/public/20/census+income.zip)\n",
    "\n",
    "We split the data vertically between the two parties as follows:\n",
    "\n",
    "Alice\n",
    "\n",
    "- workclass\n",
    "- fnlwgt\n",
    "- education\n",
    "- education_num\n",
    "- relationship\n",
    "- race\n",
    "- capital_gain\n",
    "- capital_loss\n",
    "- hours_per_week\n",
    "- income_50k (label_1)\n",
    "- marital_status (label_2)\n",
    "\n",
    "Bob\n",
    "\n",
    "- age\n",
    "- sex\n",
    "- occupation\n",
    "- native_country"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "abad3a23-0208-4c18-ac76-418f577ff32d",
   "metadata": {},
   "source": [
    "## Downloading and Processing the Data"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b230ec16-d0e4-449d-893d-f64b22c5178b",
   "metadata": {},
   "source": [
    "Download the data."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "82c66580-bef6-4190-b528-9478a892996e",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--2023-09-26 19:48:41--  https://archive.ics.uci.edu/static/public/20/census+income.zip\n",
      "Resolving archive.ics.uci.edu (archive.ics.uci.edu)... 128.195.10.252\n",
      "Connecting to archive.ics.uci.edu (archive.ics.uci.edu)|128.195.10.252|:443... connected.\n",
      "HTTP request sent, awaiting response... 200 OK\n",
      "Length: unspecified\n",
      "Saving to: ‘census+income.zip’\n",
      "\n",
      "census+income.zip       [            <=>     ] 650.11K  14.3KB/s    in 57s     \n",
      "\n",
      "2023-09-26 19:49:40 (11.4 KB/s) - ‘census+income.zip’ saved [665715]\n",
      "\n",
      "Archive:  census+income.zip\n",
      "  inflating: data_download/adult.data  \n",
      "  inflating: data_download/adult.names  \n",
      "  inflating: data_download/adult.test  \n",
      "  inflating: data_download/Index     \n",
      "  inflating: data_download/old.adult.names  \n"
     ]
    }
   ],
   "source": [
    "!mkdir data_download\n",
    "!wget https://archive.ics.uci.edu/static/public/20/census+income.zip\n",
    "!unzip -d data_download census+income.zip"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "bd038927-ccb4-4947-897d-8dcb2c4980da",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "adult.data  adult.names  adult.test  Index  old.adult.names\n"
     ]
    }
   ],
   "source": [
    "!ls data_download"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b7d52299-ee52-445d-a5bf-3d0628cb906b",
   "metadata": {},
   "source": [
    "Process the data: the numeric features are discretized into buckets, and each of the two labels is mapped to a binary class."
   ]
  },
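  {
   "cell_type": "markdown",
   "id": "bucketing-worked-example",
   "metadata": {},
   "source": [
    "As a worked illustration of the bucketing done in the next cell, assume that the first record of adult.data has age 39, fnlwgt 77516, capital_gain 2174 and hours_per_week 40 (an assumption about the raw file, which is not printed in this notebook). Dividing by the bucket width and truncating to an integer, which equals the floor for non-negative values, gives:\n",
    "$$\\left\\lfloor\\frac{39}{10}\\right\\rfloor=3,\\quad\\left\\lfloor\\frac{77516}{10000}\\right\\rfloor=7,\\quad\\left\\lfloor\\frac{2174}{1000}\\right\\rfloor=2,\\quad\\left\\lfloor\\frac{40}{5}\\right\\rfloor=8$$\n",
    "\n",
    "These bucket values agree with the first data rows printed by the `head` commands later in this notebook."
   ]
  },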
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "7d1853a3-2eae-4d9a-8ca7-9e2fdac20891",
   "metadata": {},
   "outputs": [],
   "source": [
    "gen_data_path = './data_download/test_mmoe_data_tf'\n",
    "\n",
    "\n",
    "def data_prepare():\n",
    "    import pandas as pd\n",
    "    import os\n",
    "    import shutil\n",
    "\n",
    "    column_names = [\n",
    "        'age',\n",
    "        'workclass',\n",
    "        'fnlwgt',\n",
    "        'education',\n",
    "        'education_num',\n",
    "        'marital_status',\n",
    "        'occupation',\n",
    "        'relationship',\n",
    "        'race',\n",
    "        'sex',\n",
    "        'capital_gain',\n",
    "        'capital_loss',\n",
    "        'hours_per_week',\n",
    "        'native_country',\n",
    "        'income_50k',\n",
    "    ]\n",
    "\n",
    "    train_df = pd.read_csv(\n",
    "        './data_download/adult.data',\n",
    "        delimiter=',',\n",
    "        header=None,\n",
    "        index_col=None,\n",
    "        names=column_names,\n",
    "    )\n",
    "\n",
    "    label_columns = ['income_50k', 'marital_status']\n",
    "\n",
    "    # continues feature to discrete feature\n",
    "    train_df['age'] = (train_df['age'] / 10).astype(int).astype('string')\n",
    "    train_df['fnlwgt'] = (train_df['fnlwgt'] / 10000).astype(int).astype('string')\n",
    "    train_df['education_num'] = train_df['education_num'].astype('string')\n",
    "    train_df['capital_gain'] = (\n",
    "        (train_df['capital_gain'] / 1000).astype(int).astype('string')\n",
    "    )\n",
    "    train_df['capital_loss'] = (\n",
    "        (train_df['capital_loss'] / 100).astype(int).astype('string')\n",
    "    )\n",
    "    train_df['hours_per_week'] = (\n",
    "        (train_df['hours_per_week'] / 5).astype(int).astype('string')\n",
    "    )\n",
    "\n",
    "    # label\n",
    "    train_df['income_50k'] = (train_df['income_50k'] == ' >50K').astype(int)\n",
    "    train_df['marital_status'] = (\n",
    "        train_df['marital_status'] == ' Never-married'\n",
    "    ).astype(int)\n",
    "\n",
    "    if os.path.exists(gen_data_path):\n",
    "        shutil.rmtree(gen_data_path)\n",
    "    os.mkdir(gen_data_path)\n",
    "    os.mkdir(gen_data_path + '/vocabulary')\n",
    "\n",
    "    train_df.to_csv(\n",
    "        gen_data_path + \"/train_data.csv\", index=False, sep=\"|\", encoding='utf-8'\n",
    "    )\n",
    "\n",
    "    train_data_alice = train_df[\n",
    "        [\n",
    "            'workclass',\n",
    "            'fnlwgt',\n",
    "            'education',\n",
    "            'education_num',\n",
    "            'relationship',\n",
    "            'race',\n",
    "            'capital_gain',\n",
    "            'capital_loss',\n",
    "            'hours_per_week',\n",
    "            'income_50k',\n",
    "            'marital_status',\n",
    "        ]\n",
    "    ]\n",
    "\n",
    "    train_data_bob = train_df[\n",
    "        [\n",
    "            'age',\n",
    "            'sex',\n",
    "            'occupation',\n",
    "            'native_country',\n",
    "        ]\n",
    "    ]\n",
    "\n",
    "    train_data_alice.to_csv(\n",
    "        gen_data_path + \"/train_data_alice.csv\", index=False, sep=\"|\", encoding='utf-8'\n",
    "    )\n",
    "    train_data_bob.to_csv(\n",
    "        gen_data_path + \"/train_data_bob.csv\", index=False, sep=\"|\", encoding='utf-8'\n",
    "    )\n",
    "\n",
    "    for fea in column_names:\n",
    "        if fea not in label_columns:\n",
    "            with open(gen_data_path + '/vocabulary/' + fea, 'w') as f:\n",
    "                f.write('\\n'.join(list(train_df[fea].unique())))\n",
    "\n",
    "\n",
    "data_prepare()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "981d0bc6-f0ef-433e-b021-8a70979f822f",
   "metadata": {},
   "source": [
    "At this point the data has been processed and split, producing one data file each for Alice and Bob."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "d25fef11-2bb6-4fa3-86a5-d4061ca6b31a",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "workclass|fnlwgt|education|education_num|relationship|race|capital_gain|capital_loss|hours_per_week|income_50k|marital_status\n",
      " State-gov|7| Bachelors|13| Not-in-family| White|2|0|8|0|1\n",
      " Self-emp-not-inc|8| Bachelors|13| Husband| White|0|0|2|0|0\n",
      " Private|21| HS-grad|9| Not-in-family| White|0|0|8|0|0\n",
      " Private|23| 11th|7| Husband| Black|0|0|8|0|0\n",
      " Private|33| Bachelors|13| Wife| Black|0|0|8|0|0\n",
      " Private|28| Masters|14| Wife| White|0|0|8|0|0\n",
      " Private|16| 9th|5| Not-in-family| Black|0|0|3|0|0\n",
      " Self-emp-not-inc|20| HS-grad|9| Husband| White|0|0|9|1|0\n",
      " Private|4| Masters|14| Not-in-family| White|14|0|10|1|1\n"
     ]
    }
   ],
   "source": [
    "!head data_download/test_mmoe_data_tf/train_data_alice.csv"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "ac9e06ab-d544-4ddd-ac04-013867ad0b1c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "age|sex|occupation|native_country\n",
      "3| Male| Adm-clerical| United-States\n",
      "5| Male| Exec-managerial| United-States\n",
      "3| Male| Handlers-cleaners| United-States\n",
      "5| Male| Handlers-cleaners| United-States\n",
      "2| Female| Prof-specialty| Cuba\n",
      "3| Female| Exec-managerial| United-States\n",
      "4| Female| Other-service| Jamaica\n",
      "5| Male| Exec-managerial| United-States\n",
      "3| Female| Prof-specialty| United-States\n"
     ]
    }
   ],
   "source": [
    "!head data_download/test_mmoe_data_tf/train_data_bob.csv"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c1727392-bc45-4ef8-b8a9-5c17a4b278fe",
   "metadata": {},
   "source": [
    "## Building the data_builder to Read the Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "389de9a5-e79a-413f-98d8-7981951853dc",
   "metadata": {},
   "outputs": [],
   "source": [
    "def create_dataset_builder_bob(\n",
    "    batch_size=128,\n",
    "    repeat_count=5,\n",
    "):\n",
    "    def dataset_builder(x):\n",
    "        import pandas as pd\n",
    "        import tensorflow as tf\n",
    "\n",
    "        x = [dict(t) if isinstance(t, pd.DataFrame) else t for t in x]\n",
    "        x = x[0] if len(x) == 1 else tuple(x)\n",
    "        data_set = (\n",
    "            tf.data.Dataset.from_tensor_slices(x).batch(batch_size).repeat(repeat_count)\n",
    "        )\n",
    "\n",
    "        return data_set\n",
    "\n",
    "    return dataset_builder\n",
    "\n",
    "\n",
    "def create_dataset_builder_alice(\n",
    "    batch_size=128,\n",
    "    repeat_count=5,\n",
    "):\n",
    "    def _parse_alice(row_sample, label):\n",
    "        import tensorflow as tf\n",
    "\n",
    "        y_1 = label[\"income_50k\"]\n",
    "        y_2 = label[\"marital_status\"]\n",
    "        return row_sample, (y_1, y_2)\n",
    "\n",
    "    def dataset_builder(x):\n",
    "        import pandas as pd\n",
    "        import tensorflow as tf\n",
    "\n",
    "        x = [dict(t) if isinstance(t, pd.DataFrame) else t for t in x]\n",
    "        x = x[0] if len(x) == 1 else tuple(x)\n",
    "        data_set = (\n",
    "            tf.data.Dataset.from_tensor_slices(x).batch(batch_size).repeat(repeat_count)\n",
    "        )\n",
    "\n",
    "        data_set = data_set.map(_parse_alice)\n",
    "\n",
    "        return data_set\n",
    "\n",
    "    return dataset_builder\n",
    "\n",
    "\n",
    "bs = 128\n",
    "epoch = 1\n",
    "\n",
    "data_builder_dict = {\n",
    "    alice: create_dataset_builder_alice(\n",
    "        batch_size=bs,\n",
    "        repeat_count=epoch,\n",
    "    ),\n",
    "    bob: create_dataset_builder_bob(\n",
    "        batch_size=bs,\n",
    "        repeat_count=epoch,\n",
    "    ),\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3bb79f33-cb1b-4407-b4fe-7482bf91a10a",
   "metadata": {},
   "source": [
    "## Defining the Model Structure"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "4f79f19d-89ab-4904-b708-961d6246409e",
   "metadata": {},
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "from secretflow.ml.nn.applications.sl_mmoe_tf import MMoEBase, MMoEFuse\n",
    "\n",
    "\n",
    "def create_base_model_alice():\n",
    "    # Create model\n",
    "    def create_model():\n",
    "        fea_list = [\n",
    "            'workclass',\n",
    "            'fnlwgt',\n",
    "            'education',\n",
    "            'education_num',\n",
    "            'relationship',\n",
    "            'race',\n",
    "            'capital_gain',\n",
    "            'capital_loss',\n",
    "            'hours_per_week',\n",
    "        ]\n",
    "        vocab_dict = {}\n",
    "        for fea in fea_list:\n",
    "            with open(gen_data_path + '/vocabulary/' + fea) as f:\n",
    "                vocab_dict[fea] = [line.strip() for line in f.readlines()]\n",
    "\n",
    "        def preprocess():\n",
    "            inputs = {\n",
    "                fea: tf.keras.Input(shape=(1,), dtype=tf.string) for fea in fea_list\n",
    "            }\n",
    "            outputs = {\n",
    "                fea: tf.keras.layers.StringLookup(\n",
    "                    vocabulary=vocab_dict[fea], output_mode=\"one_hot\"\n",
    "                )(inputs[fea])\n",
    "                for fea in fea_list\n",
    "            }\n",
    "\n",
    "            return tf.keras.Model(inputs=inputs, outputs=outputs)\n",
    "\n",
    "        preprocess_layer = preprocess()\n",
    "        model = MMoEBase(\n",
    "            dnn_units_size=[32],\n",
    "            preprocess_layer=preprocess_layer,\n",
    "            embedding_dim=9,\n",
    "        )\n",
    "        model.compile(\n",
    "            loss=tf.keras.losses.binary_crossentropy,\n",
    "            optimizer=tf.keras.optimizers.Adam(),\n",
    "            metrics=[\n",
    "                tf.keras.metrics.AUC(),\n",
    "                tf.keras.metrics.Precision(),\n",
    "                tf.keras.metrics.Recall(),\n",
    "            ],\n",
    "        )\n",
    "        return model  # need wrap\n",
    "\n",
    "    return create_model\n",
    "\n",
    "\n",
    "def create_base_model_bob():\n",
    "    # Create model\n",
    "    def create_model():\n",
    "        fea_list = [\n",
    "            'age',\n",
    "            'sex',\n",
    "            'occupation',\n",
    "            'native_country',\n",
    "        ]\n",
    "        vocab_dict = {}\n",
    "        for fea in fea_list:\n",
    "            with open(gen_data_path + '/vocabulary/' + fea) as f:\n",
    "                vocab_dict[fea] = [line.strip() for line in f.readlines()]\n",
    "\n",
    "        def preprocess():\n",
    "            inputs = {\n",
    "                fea: tf.keras.Input(shape=(1,), dtype=tf.string) for fea in fea_list\n",
    "            }\n",
    "            outputs = {\n",
    "                fea: tf.keras.layers.StringLookup(\n",
    "                    vocabulary=vocab_dict[fea], output_mode=\"one_hot\"\n",
    "                )(inputs[fea])\n",
    "                for fea in fea_list\n",
    "            }\n",
    "\n",
    "            return tf.keras.Model(inputs=inputs, outputs=outputs)\n",
    "\n",
    "        preprocess_layer = preprocess()\n",
    "\n",
    "        model = MMoEBase(\n",
    "            dnn_units_size=[20],\n",
    "            preprocess_layer=preprocess_layer,\n",
    "        )\n",
    "        model.compile(\n",
    "            loss=tf.keras.losses.binary_crossentropy,\n",
    "            optimizer=tf.keras.optimizers.Adam(),\n",
    "            metrics=[\n",
    "                tf.keras.metrics.AUC(),\n",
    "                tf.keras.metrics.Precision(),\n",
    "                tf.keras.metrics.Recall(),\n",
    "            ],\n",
    "        )\n",
    "        return model  # need wrap\n",
    "\n",
    "    return create_model\n",
    "\n",
    "\n",
    "def create_fuse_model():\n",
    "    # Create model\n",
    "    def create_model():\n",
    "        model = MMoEFuse(\n",
    "            num_experts=3,\n",
    "            expert_units_size=[32, 16],\n",
    "            expert_activation=\"relu\",\n",
    "            num_tasks=2,\n",
    "            gate_units_size=[],\n",
    "            gate_activation=\"\",\n",
    "            tower_units_size=[12],\n",
    "            tower_activation=\"relu\",\n",
    "            output_activation=[\"sigmoid\", \"sigmoid\"],\n",
    "        )\n",
    "        model.compile(\n",
    "            loss={\"output_1\": \"binary_crossentropy\", \"output_2\": \"binary_crossentropy\"},\n",
    "            loss_weights=[1.0, 1.0],\n",
    "            optimizer=tf.keras.optimizers.Adam(),\n",
    "            metrics={\n",
    "                \"output_1\": [\n",
    "                    tf.keras.metrics.AUC(),\n",
    "                    tf.keras.metrics.Precision(),\n",
    "                    tf.keras.metrics.Recall(),\n",
    "                ],\n",
    "                \"output_2\": [\n",
    "                    tf.keras.metrics.AUC(),\n",
    "                    tf.keras.metrics.Precision(),\n",
    "                    tf.keras.metrics.Recall(),\n",
    "                ],\n",
    "            },\n",
    "        )\n",
    "        return model\n",
    "\n",
    "    return create_model\n",
    "\n",
    "\n",
    "model_base_alice = create_base_model_alice()\n",
    "model_base_bob = create_base_model_bob()\n",
    "base_model_dict = {\n",
    "    alice: model_base_alice,\n",
    "    bob: model_base_bob,\n",
    "}\n",
    "model_fuse = create_fuse_model()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5749db77-42a7-4594-b89e-d62e37a39c7e",
   "metadata": {},
   "source": [
    "## Defining the SLModel"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "4a516ba4-56dc-4add-8074-87f1a4b2d766",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:root:Create proxy actor <class 'secretflow.ml.nn.sl.backend.tensorflow.sl_base.PYUSLTFModel'> with party alice.\n",
      "INFO:root:Create proxy actor <class 'secretflow.ml.nn.sl.backend.tensorflow.sl_base.PYUSLTFModel'> with party bob.\n"
     ]
    }
   ],
   "source": [
    "from secretflow.ml.nn import SLModel\n",
    "\n",
    "device_y = alice\n",
    "\n",
    "sl_model = SLModel(\n",
    "    base_model_dict=base_model_dict,\n",
    "    device_y=device_y,\n",
    "    model_fuse=model_fuse,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cdea7580-d127-4a7f-8896-d1e94732aefb",
   "metadata": {},
   "source": [
    "## Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "1cac5529-e74d-4b0d-b141-b3cffc697335",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:root:SL Train Params: {'x': VDataFrame(partitions={PYURuntime(alice): <secretflow.data.partition.pandas.partition.PdPartition object at 0x7f046cf3bd90>, PYURuntime(bob): <secretflow.data.partition.pandas.partition.PdPartition object at 0x7f05487bef70>}, aligned=True), 'y': VDataFrame(partitions={PYURuntime(alice): <secretflow.data.partition.pandas.partition.PdPartition object at 0x7f046cf4cbe0>}, aligned=True), 'batch_size': 128, 'epochs': 1, 'verbose': 1, 'callbacks': None, 'validation_data': None, 'shuffle': False, 'sample_weight': None, 'validation_freq': 1, 'dp_spent_step_freq': None, 'dataset_builder': {PYURuntime(alice): <function create_dataset_builder_alice.<locals>.dataset_builder at 0x7f0548830430>, PYURuntime(bob): <function create_dataset_builder_bob.<locals>.dataset_builder at 0x7f05488304c0>}, 'audit_log_params': {}, 'random_seed': 1234, 'audit_log_dir': None, 'self': <secretflow.ml.nn.sl.sl_model.SLModel object at 0x7f05487bea90>}\n",
      "100%|██████████| 255/255 [00:33<00:00,  7.72it/s, epoch: 1/1 -  train_loss:0.963053286075592  train_output_1_loss:0.45116713643074036  train_output_2_loss:0.5118862390518188  train_output_1_auc_1:0.7793986201286316  train_output_1_precision_1:0.7068676948547363  train_output_1_recall_1:0.26909834146499634  train_output_2_auc_2:0.7872133851051331  train_output_2_precision_2:0.740103542804718  train_output_2_recall_2:0.4147711396217346 ]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'train_loss': [0.9630533], 'train_output_1_loss': [0.45116714], 'train_output_2_loss': [0.51188624], 'train_output_1_auc_1': [0.7793986], 'train_output_1_precision_1': [0.7068677], 'train_output_1_recall_1': [0.26909834], 'train_output_2_auc_2': [0.7872134], 'train_output_2_precision_2': [0.74010354], 'train_output_2_recall_2': [0.41477114]}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "from secretflow.data.vertical import read_csv\n",
    "\n",
    "vdf = read_csv(\n",
    "    {\n",
    "        alice: gen_data_path + '/train_data_alice.csv',\n",
    "        bob: gen_data_path + '/train_data_bob.csv',\n",
    "    },\n",
    "    delimiter='|',\n",
    ")\n",
    "\n",
    "int_fea_list = [\n",
    "    'age',\n",
    "    'fnlwgt',\n",
    "    'education_num',\n",
    "    'capital_gain',\n",
    "    'capital_loss',\n",
    "    'hours_per_week',\n",
    "]\n",
    "for fea in int_fea_list:\n",
    "    vdf[fea] = vdf[fea].astype('string')\n",
    "\n",
    "label = vdf['income_50k', 'marital_status']\n",
    "data = vdf.drop(columns=['income_50k', 'marital_status'])\n",
    "\n",
    "device_y = alice\n",
    "\n",
    "history = sl_model.fit(\n",
    "    data,\n",
    "    label,\n",
    "    epochs=epoch,\n",
    "    batch_size=bs,\n",
    "    random_seed=1234,\n",
    "    dataset_builder=data_builder_dict,\n",
    ")\n",
    "print(history)"
   ]
  },
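  {
   "cell_type": "markdown",
   "id": "train-loss-weighted-sum-check",
   "metadata": {},
   "source": [
    "The logged `train_loss` can be checked against the two per-task losses. Assuming the total loss is the `loss_weights`-weighted sum of the per-output losses, as Keras documents for `compile` (the symbols $L$, $w_k$ and $L_k$ are introduced here only for illustration), the run above with `loss_weights=[1.0, 1.0]` gives:\n",
    "$$L=w_1L_1+w_2L_2=1.0\\times 0.451167+1.0\\times 0.511886\\approx 0.963053$$\n",
    "\n",
    "This matches the reported `train_loss` of about 0.963053, so with equal weights neither task is favored in the combined objective."
   ]
  },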
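  {
   "cell_type": "markdown",
   "id": "872e6ec5-0937-4135-a4a0-2cce21a93d23",
   "metadata": {},
   "source": [
    "## Summary\n",
    "This tutorial used a multi-task prediction problem on the Census Income dataset to demonstrate how to use MMoE in SecretFlow.\n",
    "\n",
    "You need to:\n",
    "\n",
    "1. Download and split the dataset;\n",
    "\n",
    "2. Define the data_builder that processes the data;\n",
    "\n",
    "3. Define the preprocessing layer, then call `MMoEBase` and `MMoEFuse` to define the model structure;\n",
    "\n",
    "4. Train with SLModel.\n"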
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5c57262f-b727-45bc-bfe2-9d99329b2c42",
   "metadata": {},
   "source": [
    "You can try this on your own dataset. If you have any questions, feel free to start a discussion on GitHub."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.17"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
